import sys
import logging

import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
import copy

from humanfriendly.terminal import output

sys.path.append("../")
from ide_methods.modules.snnide_vtg_multilayer_module import SNNIDEBERTSpikingMultiLayerModule
from ide_methods.modules.snn_modules import SNNFC, SNNBERTSpikingLIFFuncMultiLayer
from model.univtg_original import build_position_encoding
from ide_methods.modules.snn_vtg_modules import VTGSaliencyPool, BertConfig, BertSelfAttentionSplit, BertSelfOutput, BertIntermediate, BertOutput, BertEmbeddings, BertPooler, BertLayerNorm, BertFC
logger = logging.getLogger(__name__)
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, surrogate
from spikingjelly.activation_based import functional
from mamba_ssm import Mamba


def normalize_scores(scores):
    """
    Rescale saliency scores row by row into [0, 1] with min-max normalization.

    :param scores: Tensor of saliency scores (Batch, L_v)
    :return: Float tensor of the same shape; within each row the minimum maps
        to 0 and the maximum to (approximately) 1. A constant row maps to 0.
    """
    # Work in floating point so integer inputs do not truncate on division.
    values = scores.float()

    # Row-wise extrema, kept as (Batch, 1) so they broadcast against `values`.
    row_min = values.min(dim=1, keepdim=True).values
    row_max = values.max(dim=1, keepdim=True).values

    # The epsilon keeps the denominator non-zero when a row is constant.
    spread = row_max - row_min + 1e-8
    return (values - row_min) / spread

class ConvSpiking(nn.Module):
    """Three 1-D convolutions separated by integrate-and-fire (IF) neurons.

    The input is replicated ``T`` times along a new leading time axis, run
    through the multi-step spiking stack, and averaged over that time axis.
    ``num_layers`` is stored but the stack built here always has three
    convolutions and two ``IFNode`` activations.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size, use_cupy=False, T=4, dilation=1):
        super().__init__()
        self.num_layers = num_layers
        self.T = T

        same_pad = kernel_size // 2

        def make_conv(c_in, c_out):
            # All three convolutions share kernel size, padding and dilation.
            return layer.Conv1d(c_in, c_out, kernel_size=kernel_size, stride=1, padding=same_pad,
                                dilation=dilation, groups=1, bias=True)

        def make_neuron():
            return neuron.IFNode(surrogate_function=surrogate.ATan())

        # Positions 0..4 inside the Sequential are part of the state_dict keys.
        self.conv_fc = nn.Sequential(
            make_conv(input_dim, hidden_dim),
            make_neuron(),
            make_conv(hidden_dim, hidden_dim),
            make_neuron(),
            make_conv(hidden_dim, output_dim),
        )

        # Multi-step mode: sub-modules consume inputs with a leading time axis.
        functional.set_step_mode(self, step_mode='m')

        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x):
        """Map (batch_size, seq_len, input_dim) to (batch_size, seq_len, output_dim)."""
        # Conv1d wants channels before the sequence axis.
        channels_first = x.permute(0, 2, 1)
        # Present the same input at each of the T time steps: (T, batch, channels, seq_len).
        replicated = channels_first.unsqueeze(0).repeat(self.T, 1, 1, 1)
        per_step = self.conv_fc(replicated)
        # Average over the time axis, then restore the sequence-major layout.
        return per_step.mean(0).permute(0, 2, 1)


class LinearLayer(nn.Module):
    """Dropout followed by a linear map, with optional LayerNorm before and ReLU after."""

    def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
        super().__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            # Normalizes over the input feature size, i.e. before the projection.
            self.LayerNorm = nn.LayerNorm(in_hsz)
        self.net = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_hsz, out_hsz))

    def forward(self, x):
        """Apply the layer to ``x`` of shape (N, L, in_hsz); returns (N, L, out_hsz)."""
        hidden = self.LayerNorm(x) if self.layer_norm else x
        hidden = self.net(hidden)
        if not self.relu:
            return hidden
        return F.relu(hidden, inplace=True)

def mask_logits(inputs, mask, mask_value=-1e30):
    """Add ``mask_value`` to ``inputs`` wherever ``mask`` is 0; leave positions with mask 1 unchanged."""
    keep = mask.type(torch.float32)
    # (1 - keep) is 1 exactly at the masked-out positions.
    penalty = (1.0 - keep) * mask_value
    return inputs + penalty

class WeightedPool(nn.Module):
    """Pool a sequence into one vector using softmax weights from a learned projection."""

    def __init__(self, dim):
        super().__init__()
        projection = torch.empty(dim, 1)
        nn.init.xavier_uniform_(projection)
        self.weight = nn.Parameter(projection, requires_grad=True)
        # Registered as a parameter but not used by forward().
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, x, mask):
        """Pool ``x`` (batch_size, seq_length, dim) to (batch_size, dim); ``mask`` is 1 for positions to keep."""
        scores = torch.tensordot(x, self.weight, dims=1)  # (batch_size, seq_length, 1)
        scores = mask_logits(scores, mask=mask.unsqueeze(2))
        # Normalize over the sequence axis.
        attn = torch.softmax(scores, dim=1)
        # (batch_size, dim, seq_length) @ (batch_size, seq_length, 1) -> (batch_size, dim, 1)
        return torch.matmul(x.transpose(1, 2), attn).squeeze(2)


class ImprovedWeightedPool(nn.Module):
    """Attention pooling over a sequence with LayerNorm on the input and a scalar score bias."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # Scoring vector: projects every position to one attention logit.
        scorer = nn.Parameter(torch.empty(dim, 1))
        nn.init.xavier_uniform_(scorer)
        self.weight = scorer

        # Scalar offset added to every logit.
        self.bias = nn.Parameter(torch.zeros(1))

        self.layer_norm = nn.LayerNorm(dim)

    def forward(self, x, mask):
        """
        Pool the sequence dimension away with softmax attention weights.

        :param x: Tensor of shape (batch_size, seq_length, dim)
        :param mask: Mask tensor of shape (batch_size, seq_length), 0 for valid and 1 for padding
        :return: Pooled tensor of shape (batch_size, dim)
        """
        # Both the scores and the pooled values use the normalized input.
        normed = self.layer_norm(x)

        # One logit per position: (batch_size, seq_length)
        logits = (torch.tensordot(normed, self.weight, dims=1) + self.bias).squeeze(-1)

        # Padding positions (mask == 1) are pushed far down before the softmax.
        logits = logits - (mask * 1e9)
        attn = F.softmax(logits, dim=1)

        # (batch_size, 1, seq_length) @ (batch_size, seq_length, dim) -> (batch_size, 1, dim)
        return torch.bmm(attn.unsqueeze(1), normed).squeeze(1)
class ConvUpdated(nn.Module):
    """Stack of 1-D convolutions with a fixed kernel-size schedule.

    The first three layers use kernel sizes 5, 5 and 3; any further layer uses 3.
    The ``kernel_size`` constructor argument is accepted but not used.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
        super().__init__()
        self.num_layers = num_layers
        schedule = (5, 5, 3)

        # Channel widths at every layer boundary: input, hidden..., output.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]

        convs = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            width = schedule[idx] if idx < len(schedule) else 3
            # Half-kernel padding keeps the sequence length for odd kernels.
            convs.append(nn.Conv1d(c_in, c_out, kernel_size=width, stride=1, padding=width // 2))

        self.layers = nn.ModuleList(convs)

    def forward(self, x):
        """Map (batch, seq_len, input_dim) to (batch, seq_len, output_dim)."""
        hidden = x.permute(0, 2, 1)
        last = self.num_layers - 1
        for idx, conv in enumerate(self.layers):
            hidden = conv(hidden)
            # ReLU between layers only; the final layer stays linear.
            if idx < last:
                hidden = F.relu(hidden)
        return hidden.permute(0, 2, 1)

class Conv(nn.Module):
    """Stack of 1-D convolutions sharing one kernel size, with ReLU between layers."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
        super().__init__()
        self.num_layers = num_layers

        # Channel widths at every layer boundary: input, hidden..., output.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        same_pad = kernel_size // 2

        self.layers = nn.ModuleList(
            nn.Conv1d(c_in, c_out, kernel_size=kernel_size, stride=1, padding=same_pad)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, x):
        """Map (batch, seq_len, input_dim) to (batch, seq_len, output_dim)."""
        hidden = x.permute(0, 2, 1)
        last = self.num_layers - 1
        for idx, conv in enumerate(self.layers):
            hidden = conv(hidden)
            # The final layer is left without an activation.
            if idx < last:
                hidden = F.relu(hidden)
        return hidden.permute(0, 2, 1)

class SNNIDESTUDENTVTGMODEL(nn.Module):
    """Spiking student model producing foreground scores, span offsets and saliency for video clips.

    Video and text features are projected to the BERT hidden size, concatenated,
    and passed through ``SNNIDEBERTSpikingMultiLayerModule`` built from four groups
    of BERT-style sub-networks. The video part of the last returned layer feeds
    two ``ConvSpiking`` heads; saliency is the cosine similarity between the
    projected video features and a pooled text vector.
    """

    def __init__(self, cfg_path, position_embed, txt_position_embed, txt_dim, vid_dim,
                 input_dropout, load_pretrained = False, t_conv = 100, vth = 1., isEval=False, aux_loss=False,
                 max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2,**kwargs):
        """Build all sub-modules.

        :param cfg_path: directory holding ``config.json`` (and ``pytorch_model.bin``
            when ``load_pretrained`` is True)
        :param position_embed: callable invoked as ``position_embed(src_vid, src_vid_mask)`` in forward
        :param txt_position_embed: callable invoked as ``txt_position_embed(src_txt)`` when ``use_txt_pos``
        :param txt_dim: feature size of the raw text input
        :param vid_dim: feature size of the raw video input
        :param input_dropout: dropout rate of the input projection layers
        :param load_pretrained: load ``cfg_path + '/pytorch_model.bin'`` (non-strict) instead of
            applying ``init_bert_weights``
        :param t_conv: default ``time_step`` handed to the spiking module
        :param vth: value passed as ``vth`` to ``SNNBERTSpikingLIFFuncMultiLayer``
        :param isEval: accepted but not read in this class
        :param aux_loss: accepted but not read in this class
        :param max_v_l: stored; only sizes the span head when ``span_loss_type`` is not "l1"
        :param span_loss_type: only "l1" is handled by forward; anything else raises NotImplementedError there
        :param use_txt_pos: whether text positions come from ``txt_position_embed`` (zeros otherwise)
        :param n_input_proj: how many of the three projection layers to keep per modality
        :param kwargs: ignored
        """
        super(SNNIDESTUDENTVTGMODEL, self).__init__()

        # Hyperparameters
        self.threshold = 30
        self.time_step = t_conv
        self.vth = vth
        self.dropout = 0.0
        self.leaky = 1. #1 means IF neuron; 0<self.leaky<1 mean LIF
        self.solver = 'broy'
        # self.num_classes = num_classes

        config = BertConfig.from_json_file(cfg_path + '/config.json')
        self.config = config
        self.network_x = nn.Identity()
        # Input to all layers are spikes
        # Four groups of six sub-networks each:
        # BertFC, BertFC, BertSelfAttentionSplit, BertSelfOutput, BertIntermediate, BertOutput.
        self.network_s1 = BertFC(config)
        self.network_s2 = BertFC(config)
        self.network_s3 = BertSelfAttentionSplit(config)
        self.network_s4 = BertSelfOutput(config)
        self.network_s5 = BertIntermediate(config)
        self.network_s6 = BertOutput(config)
        self.network_s7 = BertFC(config)
        self.network_s8 = BertFC(config)
        self.network_s9 = BertSelfAttentionSplit(config)
        self.network_s10 = BertSelfOutput(config)
        self.network_s11 = BertIntermediate(config)
        self.network_s12 = BertOutput(config)
        self.network_s13 = BertFC(config)
        self.network_s14 = BertFC(config)
        self.network_s15 = BertSelfAttentionSplit(config)
        self.network_s16 = BertSelfOutput(config)
        self.network_s17 = BertIntermediate(config)
        self.network_s18 = BertOutput(config)
        self.network_s19 = BertFC(config)
        self.network_s20 = BertFC(config)
        self.network_s21 = BertSelfAttentionSplit(config)
        self.network_s22 = BertSelfOutput(config)
        self.network_s23 = BertIntermediate(config)
        self.network_s24 = BertOutput(config)

        # Feedback: optional
        self.network_s25 = VTGSaliencyPool(config) #SNNFC(config.hidden_size,config.hidden_size)

        # All 25 sub-networks, in order, wrapped into one spiking function.
        self.snn_func = SNNBERTSpikingLIFFuncMultiLayer(nn.ModuleList([self.network_s1, self.network_s2, self.network_s3, self.network_s4, self.network_s5, self.network_s6, self.network_s7, self.network_s8, self.network_s9, self.network_s10, self.network_s11, self.network_s12, self.network_s13, self.network_s14, self.network_s15, self.network_s16, self.network_s17, self.network_s18, self.network_s19, self.network_s20, self.network_s21, self.network_s22, self.network_s23, self.network_s24, self.network_s25]), self.network_x, vth=self.vth, leaky=self.leaky)

        # Independent copy taken at construction time, before any weight loading below.
        self.snn_func_copy = copy.deepcopy(self.snn_func)

        # The copy's parameters are frozen.
        for param in self.snn_func_copy.parameters():
            param.requires_grad_(False)

        self.snn_ide_conv = SNNIDEBERTSpikingMultiLayerModule(self.snn_func, self.snn_func_copy)
        self.fit_dense = nn.Linear(config.hidden_size, config.hidden_size) # Dummy for easy weight copy


        self.position_embed = position_embed
        self.txt_position_embed = txt_position_embed
        hidden_dim = config.hidden_size
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2

        # Two token types: index 0 is added to text tokens, index 1 to video tokens (see forward).
        self.token_type_embeddings = nn.Embedding(2, hidden_dim)
        self.token_type_embeddings.apply(self.init_bert_weights)

        # Conv projector
        # NOTE(review): both heads are replaced by ConvSpiking instances further down, after
        # weight init/loading; these Conv heads only exist while weights are initialized or
        # loaded — confirm this is intentional.
        self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
        self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3)  # 0: background, 1: foreground

        self.use_txt_pos = use_txt_pos
        self.n_input_proj = n_input_proj
        # ReLU on every kept projection layer except the last one.
        relu_args = [True] * 3
        relu_args[n_input_proj - 1] = False
        self.input_txt_proj = nn.Sequential(*[
                                                 LinearLayer(txt_dim, hidden_dim, layer_norm=True,
                                                             dropout=input_dropout, relu=relu_args[0]),
                                                 LinearLayer(hidden_dim, hidden_dim, layer_norm=True,
                                                             dropout=input_dropout, relu=relu_args[1]),
                                                 LinearLayer(hidden_dim, hidden_dim, layer_norm=True,
                                                             dropout=input_dropout, relu=relu_args[2])
                                             ][:n_input_proj])
        self.input_vid_proj = nn.Sequential(*[
                                                 LinearLayer(vid_dim, hidden_dim, layer_norm=True,
                                                             dropout=input_dropout, relu=relu_args[0]),
                                                 LinearLayer(hidden_dim, hidden_dim, layer_norm=True,
                                                             dropout=input_dropout, relu=relu_args[1]),
                                                 LinearLayer(hidden_dim, hidden_dim, layer_norm=True,
                                                             dropout=input_dropout, relu=relu_args[2])
                                             ][:n_input_proj])

        # MLP Projector
        self.weightedpool = WeightedPool(hidden_dim) #WeightedPool(hidden_dim)
        # Not referenced by forward() in this class.
        self.weightedpool_memory = WeightedPool(hidden_dim)
        # self.count = 1
        # self.thresh = .1
        if load_pretrained:
            # NOTE(review): the message is printed before the load is attempted.
            print('Checkpoint loaded!')
            self.from_pretrained(cfg_path + '/pytorch_model.bin')
        else:
            self.apply(self.init_bert_weights)

        # Final spiking heads. They are created after the init/load step above, so neither
        # init_bert_weights nor from_pretrained touches them here.
        self.span_embed = ConvSpiking(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=9)
        self.class_embed = ConvSpiking(hidden_dim, hidden_dim, 1, 3, kernel_size=3)

        # NOTE(review): `weight` is a local that is initialized and then discarded.
        weight = torch.empty(1, 1)
        nn.init.xavier_uniform_(weight)
        # Not referenced by forward() in this class.
        self.linear_proj = nn.Linear(75,75)

    def init_bert_weights(self, module):
        """Initialize one module in place; intended for use with ``nn.Module.apply``.

        Linear and Embedding weights are drawn from a normal distribution with std
        ``config.initializer_range``; BertLayerNorm gets weight 1 and bias 0;
        Linear biases are zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def from_pretrained(self, model_path):
        """Load a state dict from ``model_path`` onto CPU and apply it non-strictly.

        With ``strict=False``, keys that are missing from or unexpected in the
        checkpoint do not raise.
        """
        state_dict = torch.load(model_path, map_location='cpu')
        self.load_state_dict(state_dict=state_dict, strict=False)
        return

    def _forward(self, input_ids, segment_ids, attention_mask, vid_shape = 75,  **kwargs):
        """Run the spiking module, taking solver settings from ``kwargs`` or instance defaults.

        Only ``threshold``, ``time_step``, ``input_type`` and ``leaky`` are read from
        ``kwargs``; any other keyword is dropped and not forwarded.

        :return: the ``(student_rep, atts_avg)`` pair returned by ``self.snn_ide_conv``
        """
        threshold = kwargs.get('threshold', self.threshold)
        time_step = kwargs.get('time_step', self.time_step)
        input_type = kwargs.get('input_type', 'constant')
        leaky = kwargs.get('leaky', self.leaky)

        student_rep, atts_avg = self.snn_ide_conv(input_ids, segment_ids, attention_mask = attention_mask, time_step=time_step, threshold=threshold, input_type=input_type, solver_type=self.solver, leaky=leaky, vid_shape = vid_shape)

        return student_rep, atts_avg

    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, **kwargs):
        """Compute predictions for a batch of video/text pairs.

        :param src_txt: text features; assumed (bsz, L_txt, txt_dim) — TODO confirm against caller
        :param src_txt_mask: text mask, (bsz, L_txt)
        :param src_vid: video features; assumed (bsz, L_vid, vid_dim) — TODO confirm against caller
        :param src_vid_mask: video mask, (bsz, L_vid)
        :param kwargs: forwarded to ``_forward`` (threshold, time_step, input_type, leaky)
        :return: ``(out, memory, atts_avg)`` where ``out`` is a dict with keys
            ``pred_logits``, ``pred_spans``, ``src_vid_mask``, ``vid_mem_proj``,
            ``txt_mem_proj`` and ``saliency_scores``; ``memory`` is a list with
            ``fit_dense`` applied to every entry returned by the spiking module
        :raises NotImplementedError: if ``span_loss_type`` is not "l1"
        """
        # Reset the spikingjelly heads before every pass.
        functional.reset_net(self.span_embed)
        functional.reset_net(self.class_embed)
        #self.count += 1
        device_id = src_vid.device

        src_vid = self.input_vid_proj(src_vid)
        src_txt = self.input_txt_proj(src_txt)

        device_id = src_vid.device
        # type token.
        src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
        src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))

        # Original
        vid_mem_proj = src_vid
        txt_mem_proj = src_txt
        # # word-level -> sentence-level
        txt_mem_proj = self.weightedpool(txt_mem_proj, src_txt_mask).unsqueeze(1)
        # Adding log(mask + tiny) drives the score strongly negative where the mask is 0
        # (assumes a 0/1 mask — TODO confirm).
        sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()


        # if self.count >= 800:
        #     src_vid_mask = normalize_scores(sim) > min(self.thresh, .6)
        #     if self.count % 500 == 0:
        #         self.thresh += .1
        #         print('Threshold Updated!')


        # Newly added (Mask same as saliency score)
        # vid_mem_proj = src_vid
        # txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
        # src_vid_mask = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) > 0.05 #+ (src_vid_mask + 1e-45).log()
        # End newly added

        src = torch.cat([src_vid, src_txt], dim=1)  # (bsz, L_vid+L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool()  # (bsz, L_vid+L_txt)

        pos_vid = self.position_embed(src_vid, src_vid_mask)  # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt)  # (bsz, L_txt, d)
        pos = torch.cat([pos_vid, pos_txt], dim=1)

        # NOTE(review): `position` lands in _forward's **kwargs, which never reads it, so `pos`
        # does not reach the spiking module — confirm whether positions should be used.
        memory, atts_avg = self._forward(src, segment_ids = None, attention_mask = mask, position = pos, vid_shape = src_vid.shape[1], **kwargs) #attention_mask = mask



        # The last entry of `memory` feeds the prediction heads.
        final_memory = memory[-1]
        vid_mem = final_memory[:, :src_vid.shape[1], :]  # (bsz, L_vid, d) eg values (4, 92, 1024)

        outputs_class = self.class_embed(vid_mem).sigmoid()  # (bsz, L_vid, 1) per ConvSpiking.forward
        outputs_coord = self.span_embed(vid_mem)  # (bsz, L_vid, 2 or max_v_l * 2) per ConvSpiking.forward

        if self.span_loss_type == "l1":
            outputs_coord = outputs_coord.sigmoid()
            # Sign pattern (-1, +1) on the last axis: the first value becomes negative,
            # the second stays positive.
            idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).to(device_id)
            idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
            outputs_coord = outputs_coord * idx_mask
        else:
            raise NotImplementedError

        # New
        # vid_mem_proj = final_memory[:, :src_vid.shape[1],:]
        # txt_mem_proj = src_txt
        # # # word-level -> sentence-level
        # txt_mem_proj = self.weightedpool(txt_mem_proj, src_txt_mask).unsqueeze(1)
        # sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()

        out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
               'src_vid_mask': src_vid_mask}
        out["vid_mem_proj"] = vid_mem_proj
        out["txt_mem_proj"] = txt_mem_proj




        #sim = (sim1 * sim2)

        # Saliency comes from the input projections computed above, not from `memory`.
        out["saliency_scores"] = sim

        tmp = []
        for s_id, sequence_layer in enumerate(memory):
            # Fit dense is W_td as described in the paper.
            tmp.append(self.fit_dense(sequence_layer))
        memory = tmp

        return out, memory, atts_avg

def build_spiking_model(args):
    """Build the spiking student model from parsed arguments.

    The config/checkpoint directory is read from the optional attribute
    ``args.spiking_cfg_path``. When that attribute is missing or empty, the
    previous hard-coded location is used, so existing callers are unaffected.

    :param args: namespace providing ``t_feat_dim``, ``v_feat_dim``,
        ``input_dropout``, ``span_loss_type``, ``use_txt_pos``, ``n_input_proj``
        and whatever ``build_position_encoding`` reads
    :return: an ``SNNIDESTUDENTVTGMODEL`` built with ``t_conv=50`` and
        ``load_pretrained=True``; it is not moved to any device here
    """
    # Fallback kept for backward compatibility; it is specific to one machine.
    default_cfg_path = '/homes/e35839/porcupine/SpikingVTG/UniVTG/spiking-student-model'
    cfg_path = getattr(args, 'spiking_cfg_path', None) or default_cfg_path

    position_embedding, txt_position_embedding = build_position_encoding(args)
    model = SNNIDESTUDENTVTGMODEL(
        cfg_path,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        input_dropout=args.input_dropout,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
        t_conv=50,
        load_pretrained=True
    )

    return model


# 2 saliency also uses output
# 1 normal
# TODO asr WITH NORM SCORE
